#!/usr/bin/python
# -*- coding: utf-8 -*-
# Copyright (c) 2013 Thomas M. Breuel; see accompanying LICENSE file

# The star import stays first so that the explicit imports below win if
# pylab happens to export a name of the same spelling.
from pylab import *
import glob,re,os
import sys
import ocrolib
from ocrolib import fvariant
from ocrolib.lattice import Lattice

# These are replacements that happen in the ground truth
# prior to any alignment.  FIXME make this configurable.

# The defaults replace all quotation marks with typewriter
# single quotation marks (a double quote becomes two of them),
# mostly because the markup is so inconsistent.

# The double quotation marks will later be treated just
# like a ligature, with '"' as the automatically
# inserted ligature character.

# Note that we also add "''" as a potential ligature
# back in again to the recognizer, so that double quotes
# that can't be segmented can still be recognized.

# Ordered (pattern, replacement) pairs applied to the ground truth.
# Single-quote look-alikes collapse to one apostrophe, double-quote
# look-alikes to two apostrophes.
replacements = [
    (u'"', u"''"),       # typewriter double quote
    (u"`", u"'"),        # grave accent
    (u"\u00b4", u"'"),   # acute accent
    (u"\u2018", u"'"),   # left single quotation mark
    (u"\u2019", u"'"),   # right single quotation mark
    (u"\u201c", u"''"),  # left double quotation mark
    (u"\u201d", u"''"),  # right double quotation mark
]

def normalize(s):
    """Return `s` with every entry of `replacements` applied in order.

    Each pattern is handed to `re.sub`, so it is interpreted as a regular
    expression, not as a literal string.
    """
    text = s
    for pattern, substitute in replacements:
        text = re.sub(pattern, substitute, text)
    return text

class Path:
    """A typical state in a dynamic programming problem.

    Attributes:
        cost: cost of the partial alignment represented by this path.
        state: current state in the lattice.
        pos: current position in the ground truth string.
        sequence: list of the lattice edges taken so far.
        path: list of the labels emitted so far (parallel to `sequence`).

    Paths order by (cost, state, path), so sorting a list of paths puts
    the cheapest one first.
    """
    def __init__(self,cost=0.0,state=-1,pos=-1,sequence=None,path=None):
        self.cost = cost
        self.state = state # current state in lattice
        self.pos = pos # current position in string
        # Build a fresh list per instance: a literal [] default would be a
        # single list object shared by every Path created without the
        # argument, so mutating one path would silently change the others.
        self.sequence = [] if sequence is None else sequence # sequence of edges
        self.path = [] if path is None else path # sequence of labels
    def __repr__(self):
        return "<Path %.2f %d '%s'>"%(self.cost,self.state,self.path)
    def __str__(self):
        return self.__repr__()
    def __cmp__(self,other):
        # cost comes first in the key; state and path only break ties
        return cmp((self.cost,self.state,self.path),(other.cost,other.state,other.path))


def expand(state,gt,mismatch_cost=50,ins_cost=50,verbose=0):
    """Expand the state by another step.

    Follows every lattice edge leaving `state.state` and returns the list
    of successor `Path` objects.  Up to three kinds of transition are
    generated per edge:

    - a mismatch transition that consumes `len(e.cls)` ground truth
      characters at `mismatch_cost` each (at most one per (start,stop)
      pair of lattice states, and only for edges with a non-empty class);
    - one ligature/lenient transition for every entry of the global
      `ligatures`/`lenients` lists that the remaining ground truth starts
      with, if `e.seg[1]-e.seg[0]` is below `args.maxligskip` and
      `e.seg[0]` is non-zero;
    - a regular transition at the edge's own cost if the remaining ground
      truth starts with the edge's class.

    Reads the module-level names `lattice`, `ligatures`, `lenients` and
    `args`.  `ins_cost` and `verbose` are accepted for compatibility but
    are not used: insertion transitions are disabled.
    """
    result = []
    edges = lattice.edges[state.state]
    rest = gt[state.pos:]
    done = set()
    # The candidate ligatures depend only on the remaining ground truth,
    # not on the edge, so they are computed once per call.
    matches = [l for l in ligatures+lenients if rest.startswith(l)]
    for e in edges:
        assert e.start==state.state

        # add a mismatch transition with the mismatch cost
        if (e.start,e.stop) not in done and len(e.cls)>0:
            p = Path(cost = state.cost + len(e.cls)*mismatch_cost,
                     state = e.stop,
                     pos = state.pos + len(e.cls),
                     path = state.path + [rest[:len(e.cls)]],
                     sequence = state.sequence + [e])
            result.append(p)
            done.add((e.start,e.stop))

        # for the defined set of ligatures, add low-cost ligature
        # transitions if the edge represents one segment
        if e.seg[1]-e.seg[0]<args.maxligskip and e.seg[0]!=0 and matches!=[]:
            for l in matches:
                cost = args.ligcost if l in ligatures else args.lenientcost
                p = Path(cost = state.cost + cost,
                         state = e.stop,
                         pos = state.pos + len(l),
                         path = state.path + [rest[:len(l)]],
                         sequence = state.sequence + [e])
                result.append(p)

        # add a regular transition if the class matches the rest
        # of the ground truth
        if rest.startswith(e.cls):
            p = Path(cost = state.cost + e.cost,
                     state = e.stop,
                     pos = state.pos + len(e.cls),
                     path = state.path + [e.cls],
                     sequence = state.sequence + [e])
            # Append only inside the branch; appending unconditionally
            # would re-add whichever path was created last.
            result.append(p)
    return result

def eliminate_duplicates_and_sort(paths):
    """Return one path per distinct label string, in sorted order.

    Two paths count as duplicates when their `path` labels concatenate to
    the same string; of each such group only the path that sorts first is
    kept.
    """
    survivors = {}
    # Walking the paths in sorted order means the first one seen for a
    # given label string is the one to keep.
    for candidate in sorted(paths):
        survivors.setdefault("".join(candidate.path), candidate)
    return sorted(survivors.values())

def search(lattice,gt,accept=None,verbose=0,Verbose=0,beam=100,**kw):
    """Search the lattice for the best path corresponding to the
    ground truth, using a dynamic programming algorithm."""
    global table
    initial = Path(cost=0.0,state=lattice.startState(),pos=0)
    nstates = lattice.lastState()+1
    table = [[] for i in range(nstates)]
    table[initial.state] = [initial]
    for i in range(nstates):
        if lattice.isAccept(i): break
        if len(table[i])==0: continue
        table[i] = eliminate_duplicates_and_sort(table[i])
        if args.debugsearch: print table[i][0]
        if verbose>0: print i,table[i][0]
        for rank,s in enumerate(table[i][:beam]):
            expanded = expand(s,gt,verbose=0,**kw)
            for e in expanded:
                if Verbose: print "    ",e
                table[e.state].append(e)
    result = eliminate_duplicates_and_sort(table[i])
    if args.debugsearch: print "final"; print result[0]
    return result

import argparse

# Command line definition.  The usage text is passed as `description`: the
# first positional parameter of ArgumentParser is `prog`, so passing the
# text positionally would put it on the "usage:" line in place of the
# program name.  RawDescriptionHelpFormatter keeps its line breaks.
parser = argparse.ArgumentParser(
    description="""
Aligns recognizer output with ground truth in order to obtain character
training data.

Inputs: line.lattice, line.gt.txt 
Outputs: line.cseg.png, line.aligned
""",
    formatter_class=argparse.RawDescriptionHelpFormatter)

# NOTE(review): the language-model options (--build ... --cweight), --beam
# and --thresh are not read anywhere in the visible part of this script;
# they look inherited from a related tool -- confirm before relying on them.
parser.add_argument('--build',default=None,help="build and write a language model")
parser.add_argument('--ngraph',type=int,default=4,help="order of the language model")
parser.add_argument('--sample',default=None,type=int,help="sample from the language model")
parser.add_argument('--slength',default=70,type=int,help="length of the sampled strings")
parser.add_argument('-l','--lmodel',default=ocrolib.default.ngraphs,help="the language model")
parser.add_argument('-L','--lweight',default=0.5,type=float,help="language model weight")
parser.add_argument('-C','--cweight',default=1.0,type=float,help="character weight")
parser.add_argument('-B','--beam',default=10,type=int,help="beam width")
parser.add_argument('-W','--maxws',default=5.0,type=float,help="max whitespace cost")
parser.add_argument('-M','--maxcost',default=20.0,type=float,help="max cost")
parser.add_argument('-X','--mismatch',default=30.0,type=float,help="mismatch cost")
parser.add_argument('-T','--thresh',default=0.5,type=float,help="below this cost, ignore language model")
parser.add_argument('-v','--verbose',action="store_true")
parser.add_argument('-V','--Verbose',action="store_true")
parser.add_argument('-q','--quiet',action="store_true",help="don't output each line")

# These help align specific kinds of characters that may not be in the
# training set.  Use them, for example, if you want to train a Latin
# alphabet with a few extra characters, or to add new ligatures.

# Whitespace-separated list; it is split into individual entries below.
default_ligatures = """fi ff ffi fl ft ck 
li th Th re ri ry rt rn rm as th oo 00 000 po pe ''
"""

parser.add_argument('--ligatures',default=default_ligatures,help="optional ligatures")
parser.add_argument('--ligcost',type=float,default=3.0,help="cost of inserting a ligature instead of recognizing individual characters")
parser.add_argument('--lenients',default="_ ~",help="character sequences matched 'leniently'")
parser.add_argument('--lenientcost',type=float,default=1.0,help="cost of inserting a leniently matched character")
parser.add_argument('--maxligskip',type=int,default=5,help="max number of segments allowed to participate in inferred ligatures")

parser.add_argument('--debugpath',action="store_true",help="print the final aligned path")
parser.add_argument('--debugsearch',action="store_true",help="print the states explored during the search")

parser.add_argument('files',nargs='*')
args = parser.parse_args()
files = args.files

# Module-level lists consulted by expand() when proposing ligature and
# lenient transitions.
ligatures = args.ligatures.split()
lenients = args.lenients.split()

fnames = []
for pattern in args.files:
    l = sorted(glob.glob(pattern))
    for f in l:
        assert ".lattice" in f,"all files must end with .lattice"
        base,_ = ocrolib.allsplitext(f)
        if not os.path.exists(base+".gt.txt"):
            print f,": no ground truth, skipping"
        else:
            fnames += [f]

if len(fnames)==0:
    parser.print_help()
    sys.exit(0)

print "processing",len(fnames),"files"
for fname in fnames:
    if not args.quiet: print fname,"=ALIGNED=",
    base,_ = ocrolib.allsplitext(fname)
    if not os.path.exists(fvariant(base,"txt","gt")):
        print "ERROR","no ground truth, skipping"
        continue
    with open(fvariant(base,"txt","gt")) as stream:
        gt = stream.read()
        if gt[-1]=="\n": gt = gt[:-1]
        gt = re.sub(r'\s+$','',gt)
    gt = normalize(gt)
    if args.verbose:
        print "gt=",gt
    lattice = Lattice(maxws=args.maxws,maxcost=args.maxcost,mismatch=args.mismatch)
    lattice.readLattice(fname)
    result = search(lattice,gt,verbose=args.verbose,Verbose=args.Verbose)
    path = result[0]
    gt = []
    mapping = zeros(10000,'i')
    for i,e in enumerate(path.sequence):
        c = path.path[i]
        if args.debugpath: print i,"aligned",c,"edge",e
        if c=="": continue
        gt.append(c)
        if c==" ":
            assert e.seg[1]==0
        else:
            for s in range(e.seg[0],e.seg[1]+1):
                if e.seg[0]==0: continue
                if args.debugpath: print "   mapping",s,"->",len(gt)
                mapping[s] = len(gt)
    if os.path.exists(fvariant(fname,"rseg")):
        seg = ocrolib.read_line_segmentation(fvariant(fname,"rseg"))
        seg = mapping[seg]
        ocrolib.write_line_segmentation(fvariant(fname,"cseg"),seg)
    else:
        print "ERROR",fvariant(base,"rseg"),": not found"
    gt = ocrolib.gt_implode(gt)
    ocrolib.write_text(fvariant(fname,"aligned"),gt)
    if not args.quiet: print "%6.1f  %s"%(path.cost,gt)
